package io.kiki.sba.registry.concurrent;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.Weigher;
import io.kiki.sba.registry.cache.CacheCleaner;

import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;

/**
 * Memoizes the result of a {@link Callable} per key in a Guava {@link Cache}, so repeated
 * {@link #execute} calls for the same key reuse the cached value instead of invoking the
 * callable again until the entry expires.
 *
 * <p>Entries expire {@code silentMs} milliseconds after they were written, or after they were
 * last accessed when {@code expireAfterAccess} is {@code true}. Every constructor also hands the
 * cache to {@code CacheCleaner.autoClean(cache, silentMs)} (presumably to drive periodic cleanup
 * of expired entries — confirm against {@code CacheCleaner}).
 *
 * <p>Hit/miss statistics are kept in {@link LongAdder}s and exposed via {@link #getHitCount()}
 * and {@link #getMissingCount()}.
 *
 * @param <K> cache key type
 * @param <V> cached value type
 */
public class CachedExecutor<K, V> {
    private final Cache<K, V> cache;
    private final LongAdder hitCount = new LongAdder();
    private final LongAdder missingCount = new LongAdder();

    /**
     * Creates an unbounded executor whose entries expire {@code silentMs} milliseconds after write.
     *
     * @param silentMs expiry window in milliseconds
     */
    public CachedExecutor(long silentMs) {
        this(silentMs, false);
    }

    /**
     * Creates an unbounded executor.
     *
     * @param silentMs          expiry window in milliseconds
     * @param expireAfterAccess {@code true} to expire entries {@code silentMs} after last access,
     *                          {@code false} to expire them {@code silentMs} after write
     */
    public CachedExecutor(long silentMs, boolean expireAfterAccess) {
        CacheBuilder<Object, Object> builder = CacheBuilder.newBuilder();
        this.cache = withExpiry(builder, silentMs, expireAfterAccess).build();
        CacheCleaner.autoClean(cache, silentMs);
    }

    /**
     * Creates an executor whose cache is bounded by total entry weight.
     *
     * @param silentMs          expiry window in milliseconds
     * @param maxWeight         maximum total weight passed to {@link CacheBuilder#maximumWeight}
     * @param weigher           computes the weight of each entry
     * @param expireAfterAccess {@code true} to expire entries {@code silentMs} after last access,
     *                          {@code false} to expire them {@code silentMs} after write
     */
    public CachedExecutor(long silentMs, long maxWeight, Weigher<K, V> weigher, boolean expireAfterAccess) {
        // Typed builder: the compiler now checks that the weigher matches the cache's K/V.
        CacheBuilder<K, V> builder = CacheBuilder.newBuilder().maximumWeight(maxWeight).weigher(weigher);
        this.cache = withExpiry(builder, silentMs, expireAfterAccess).build();
        CacheCleaner.autoClean(cache, silentMs);
    }

    /** Applies the shared expiry policy (after-access or after-write) to {@code builder}. */
    private static <K1, V1> CacheBuilder<K1, V1> withExpiry(
            CacheBuilder<K1, V1> builder, long silentMs, boolean expireAfterAccess) {
        return expireAfterAccess
                ? builder.expireAfterAccess(silentMs, TimeUnit.MILLISECONDS)
                : builder.expireAfterWrite(silentMs, TimeUnit.MILLISECONDS);
    }

    /**
     * Returns the cached value for {@code key}, invoking {@code callable} to compute it on a miss.
     *
     * <p>A call that finds the value already cached, or receives a value loaded concurrently by
     * another call for the same key, counts as a hit. A call whose own loader runs counts as a
     * miss and triggers {@link #onMiss}.
     *
     * @param key      cache key
     * @param callable computes the value when it is not cached
     * @return the cached or freshly computed value
     * @throws Exception as propagated by Guava's {@code Cache#get(K, Callable)}: a checked exception
     *                   thrown by {@code callable} arrives wrapped in
     *                   {@code java.util.concurrent.ExecutionException}, an unchecked one in
     *                   {@code UncheckedExecutionException}, and a {@code null} result raises
     *                   {@code InvalidCacheLoadException}
     */
    public V execute(K key, Callable<V> callable) throws Exception {
        V cached = cache.getIfPresent(key);
        if (cached != null) {
            hitCount.increment();
            return cached;
        }
        // The loader only runs when this call performs the load. A call that waits on a concurrent
        // load of the same key gets that value without invoking it, so track which case we hit.
        boolean[] loadedHere = {false};
        V value = cache.get(key, () -> {
            loadedHere[0] = true;
            missingCount.increment();
            onMiss(key);
            return callable.call();
        });
        if (!loadedHere[0]) {
            // Served from the cache after all (populated by another caller after our fast-path check).
            hitCount.increment();
        }
        return value;
    }

    /**
     * Hook invoked inside the cache loader, just before the callable runs, once for each miss
     * this executor loads itself. The default implementation does nothing.
     *
     * @param key the key being loaded
     */
    protected void onMiss(K key) {
    }

    /** Invalidates every cached entry. Hit and miss counters are not reset. */
    public void clean() {
        cache.invalidateAll();
    }

    /** @return number of {@link #execute} calls served from the cache */
    public long getHitCount() {
        return hitCount.longValue();
    }

    /** @return number of {@link #execute} calls that invoked their callable */
    public long getMissingCount() {
        return missingCount.longValue();
    }
}
